import argparse
import os
import shutil
 
 
def split_dataset(root_dir, output_dir_A, output_dir_B, k, k_factor=2, data_type="all", no_first=False):
    """Split a two-level directory tree into a few-shot part (A) and the rest (B).

    ``root_dir`` is walked as ``root_dir/<first_level>/<second_level>``.  For every
    second-level *directory*, the selected files are sorted by name; the first
    ``k * k_factor`` are copied to ``output_dir_A/<first_level>/<second_level>`` and
    the remaining ones to ``output_dir_B/<first_level>/<second_level>``.
    Second-level entries that are plain files (e.g. yaml metadata) are copied into
    ``<output>/<first_level>`` of BOTH outputs.  Non-directory entries directly
    under ``root_dir`` are ignored.

    Args:
        root_dir: Source dataset root.
        output_dir_A: Destination for the first ``k * k_factor`` files of each
            second-level directory.
        output_dir_B: Destination for the remaining files.
        k: Number of shots; ``k * k_factor`` files go to A per directory.
        k_factor: Number of selected files per shot.  Every second-level
            directory must contain a multiple of this many selected files.
        data_type: ``"all"`` selects every regular file, ``"heter"`` only
            ``.pcd`` files, ``"camera"`` only ``.png`` files.
        no_first: If true, skip the first (name-sorted) entry of every
            first-level directory.

    Raises:
        ValueError: If ``data_type`` is not one of the supported values, or a
            second-level directory holds a number of selected files that is not
            a multiple of ``k_factor``.
    """
    # Filename suffix kept for each data type; None means "keep every file".
    suffix_by_type = {"all": None, "heter": ".pcd", "camera": ".png"}
    if data_type not in suffix_by_type:
        raise ValueError(
            f"Unknown data_type {data_type!r}; expected one of {sorted(suffix_by_type)}"
        )
    suffix = suffix_by_type[data_type]
    select_k = k * k_factor

    for first_level in sorted(os.listdir(root_dir)):
        first_level_path = os.path.join(root_dir, first_level)
        if not os.path.isdir(first_level_path):
            continue

        print(f"Processing: {first_level}")
        second_level_list = sorted(os.listdir(first_level_path))
        # Guard the [0] access: an empty first-level directory has nothing to drop.
        if no_first and second_level_list:
            print("Remove first: ", second_level_list[0])
            second_level_list = second_level_list[1:]

        for second_level in second_level_list:
            second_level_path = os.path.join(first_level_path, second_level)
            if not os.path.isdir(second_level_path):
                # Loose files at this level (e.g. yaml) are shared by both splits.
                for out_dir in (output_dir_A, output_dir_B):
                    target = os.path.join(out_dir, first_level)
                    os.makedirs(target, exist_ok=True)
                    shutil.copy2(second_level_path, target)
                continue

            # Only regular files can be copied with copy2; nested dirs are skipped.
            file_list = sorted(
                name
                for name in os.listdir(second_level_path)
                if os.path.isfile(os.path.join(second_level_path, name))
                and (suffix is None or name.endswith(suffix))
            )
            if len(file_list) % k_factor != 0:
                raise ValueError(
                    f"{second_level_path} has {len(file_list)} selected files, "
                    f"which is not a multiple of k_factor={k_factor}"
                )

            dest_a_path = os.path.join(output_dir_A, first_level, second_level)
            dest_b_path = os.path.join(output_dir_B, first_level, second_level)
            os.makedirs(dest_a_path, exist_ok=True)
            os.makedirs(dest_b_path, exist_ok=True)
            for dest_path, chunk in (
                (dest_a_path, file_list[:select_k]),
                (dest_b_path, file_list[select_k:]),
            ):
                for file in chunk:
                    shutil.copy2(
                        os.path.join(second_level_path, file),
                        os.path.join(dest_path, file),
                    )
            





if __name__ == "__main__":
    # Selected files per shot for each data type: split_dataset copies
    # k * factor files per second-level directory into the train split.
    data_factor_dict = {"all": 6, "heter": 2, "camera": 5}

    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, required=True, help="origin dir")
    parser.add_argument("--output_dir", type=str, required=True, help="output dir")
    parser.add_argument("--k", type=int, required=True, help="number of shots")
    parser.add_argument(
        "--data_type",
        choices=sorted(data_factor_dict),
        default="all",
        help="files to split: all files, only .pcd (heter) or only .png (camera)",
    )
    args = parser.parse_args()
    root_directory = args.data_dir
    base_output_dir = args.output_dir

    data_type = args.data_type
    k_factor = data_factor_dict[data_type]

    shot_list = [args.k]
    for i in shot_list:
        output_A = os.path.join(base_output_dir, f"OPV2V_{i}", "train")
        output_B = os.path.join(base_output_dir, f"OPV2V_{i}", "test")

        os.makedirs(output_A, exist_ok=True)
        os.makedirs(output_B, exist_ok=True)

        print(f"start split{output_A}, {output_B}")
        # NOTE(review): the first name-sorted entry of every first-level dir is
        # always dropped (no_first=True); confirm why this is required.
        split_dataset(
            root_directory,
            output_A,
            output_B,
            i,
            k_factor=k_factor,
            data_type=data_type,
            no_first=True,
        )
        print(f"split done：{output_A}, {output_B}")

